{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "979a6cef",
   "metadata": {},
   "source": [
    "这里我们使用VGGNet测试Transfer Learning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9b0db484",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Use device:  cuda\n"
     ]
    }
   ],
   "source": [
    "from hdd.device.utils import get_device\n",
    "from hdd.dataset.imagenette_in_memory import ImagenetteInMemory\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "# Root directory that holds the training / validation data.\n",
    "DATA_ROOT = \"~/workspace/hands-dirty-on-dl/dataset\"\n",
    "# Root directory for TensorBoard output.\n",
    "# NOTE(review): same path as DATA_ROOT and not used by the visible cells - confirm it is intended.\n",
    "TENSORBOARD_ROOT = \"~/workspace/hands-dirty-on-dl/dataset\"\n",
    "# Directory where torch.hub stores downloaded pretrained weights.\n",
    "TORCH_HUB_PATH = \"~/workspace/hands-dirty-on-dl/pretrained_models\"\n",
    "torch.hub.set_dir(TORCH_HUB_PATH)\n",
    "# Seed PyTorch's RNG so weight initialisation and DataLoader shuffling repeat across runs.\n",
    "# Some CUDA kernels are non-deterministic, so results may still differ slightly.\n",
    "SEED = 42\n",
    "torch.manual_seed(SEED)\n",
    "# Pick the most suitable training device among the listed candidates.\n",
    "DEVICE = get_device([\"cuda\", \"cpu\"])\n",
    "print(\"Use device: \", DEVICE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be88c698",
   "metadata": {},
   "source": [
    "加载VGG模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8a88436e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VGG(\n",
      "  (features): Sequential(\n",
      "    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (2): ReLU(inplace=True)\n",
      "    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (5): ReLU(inplace=True)\n",
      "    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (9): ReLU(inplace=True)\n",
      "    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (12): ReLU(inplace=True)\n",
      "    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (16): ReLU(inplace=True)\n",
      "    (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (19): ReLU(inplace=True)\n",
      "    (20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (22): ReLU(inplace=True)\n",
      "    (23): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (24): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (25): ReLU(inplace=True)\n",
      "    (26): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (27): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (29): ReLU(inplace=True)\n",
      "    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (32): ReLU(inplace=True)\n",
      "    (33): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (34): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (35): ReLU(inplace=True)\n",
      "    (36): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (37): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (38): ReLU(inplace=True)\n",
      "    (39): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "    (40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (42): ReLU(inplace=True)\n",
      "    (43): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (44): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (45): ReLU(inplace=True)\n",
      "    (46): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (47): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (48): ReLU(inplace=True)\n",
      "    (49): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (50): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (51): ReLU(inplace=True)\n",
      "    (52): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  )\n",
      "  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))\n",
      "  (classifier): Sequential(\n",
      "    (0): Linear(in_features=25088, out_features=4096, bias=True)\n",
      "    (1): ReLU(inplace=True)\n",
      "    (2): Dropout(p=0.5, inplace=False)\n",
      "    (3): Linear(in_features=4096, out_features=4096, bias=True)\n",
      "    (4): ReLU(inplace=True)\n",
      "    (5): Dropout(p=0.5, inplace=False)\n",
      "    (6): Linear(in_features=4096, out_features=1000, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "from torchvision.models import vgg19_bn, VGG19_BN_Weights\n",
    "from torchsummary import summary\n",
    "\n",
    "# Build VGG-19 (batch-norm variant) with the ImageNet-1k V1 weights, place it on DEVICE and list its layers.\n",
    "model = vgg19_bn(weights=VGG19_BN_Weights.IMAGENET1K_V1).to(DEVICE)\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "899c5ef7",
   "metadata": {},
   "source": [
    "创建与VGG预训练模型匹配的Data Transform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a436856e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.transforms._presets import ImageClassification\n",
    "from hdd.data_util.transforms import RandomMetaTransform\n",
    "from hdd.dataset.imagenette_in_memory import get_imagenette_label_to_imagenet_label\n",
    "\n",
    "# Training preprocessing: three ImageClassification presets, all with crop_size=224 and\n",
    "# resize_size 224 / 238 / 296, wrapped in RandomMetaTransform.\n",
    "# NOTE(review): presumably one preset is chosen at random per sample - confirm in hdd.data_util.transforms.\n",
    "train_transform = RandomMetaTransform(\n",
    "    *(\n",
    "        ImageClassification(crop_size=224, resize_size=resize)\n",
    "        for resize in (224, 238, 296)\n",
    "    )\n",
    ")\n",
    "# Validation uses the preprocessing that ships with the pretrained weights.\n",
    "val_transform = VGG19_BN_Weights.IMAGENET1K_V1.transforms()\n",
    "BATCH_SIZE = 32\n",
    "\n",
    "\n",
    "def _build_loader(split, transform, shuffle):\n",
    "    \"\"\"Return a DataLoader over the requested Imagenette split with the shared loader settings.\"\"\"\n",
    "    split_dataset = ImagenetteInMemory(\n",
    "        root=DATA_ROOT,\n",
    "        split=split,\n",
    "        size=\"full\",\n",
    "        download=True,\n",
    "        transform=transform,\n",
    "    )\n",
    "    return torch.utils.data.DataLoader(\n",
    "        split_dataset,\n",
    "        batch_size=BATCH_SIZE,\n",
    "        shuffle=shuffle,\n",
    "        num_workers=8,\n",
    "        pin_memory=True,\n",
    "    )\n",
    "\n",
    "\n",
    "# Only the training loader shuffles; validation keeps dataset order.\n",
    "train_dataloader = _build_loader(\"train\", train_transform, shuffle=True)\n",
    "val_dataloader = _build_loader(\"val\", val_transform, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "670fa6be",
   "metadata": {},
   "source": [
    "#### 在没有任何Fine Tuning的情况下,准确率为**85.78%**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "cba078f7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy without any fine tuning: 0.8578343949044586\n"
     ]
    }
   ],
   "source": [
    "from hdd.train.classification_utils import eval_image_classifier\n",
    "\n",
    "# The unmodified network predicts ImageNet classes, so ground-truth Imagenette labels are\n",
    "# translated through this lookup before being compared with the predictions.\n",
    "imagenette_to_imagenet = get_imagenette_label_to_imagenet_label()\n",
    "# Inference needs BatchNorm / Dropout in eval mode. Whether eval_image_classifier switches\n",
    "# modes itself is not visible here, so set it explicitly.\n",
    "model.eval()\n",
    "eval_result = eval_image_classifier(model, val_dataloader.dataset, DEVICE)\n",
    "# One boolean per evaluated sample: True when the prediction matches the mapped label.\n",
    "is_correct = [\n",
    "    imagenette_to_imagenet[result.gt_label] == result.predicted_label\n",
    "    for result in eval_result\n",
    "]\n",
    "print(f\"Accuracy without any fine tuning: {sum(is_correct) / len(is_correct)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "164a0501",
   "metadata": {},
   "source": [
    "#### 微调Classifier,准确率为**98.55%**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4aa80792",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1/10 Train Loss: 0.4495 Accuracy: 0.8852 Time: 14.13914  | Val Loss: 0.0627 Accuracy: 0.9801\n",
      "Epoch: 2/10 Train Loss: 0.0550 Accuracy: 0.9845 Time: 14.06363  | Val Loss: 0.0551 Accuracy: 0.9837\n",
      "Epoch: 3/10 Train Loss: 0.0304 Accuracy: 0.9908 Time: 14.07111  | Val Loss: 0.0491 Accuracy: 0.9834\n",
      "Epoch: 4/10 Train Loss: 0.0180 Accuracy: 0.9958 Time: 14.10516  | Val Loss: 0.0491 Accuracy: 0.9862\n",
      "Epoch: 5/10 Train Loss: 0.0120 Accuracy: 0.9968 Time: 13.99336  | Val Loss: 0.0477 Accuracy: 0.9852\n",
      "Epoch: 6/10 Train Loss: 0.0095 Accuracy: 0.9975 Time: 14.03133  | Val Loss: 0.0543 Accuracy: 0.9850\n",
      "Epoch: 7/10 Train Loss: 0.0110 Accuracy: 0.9969 Time: 13.93512  | Val Loss: 0.0650 Accuracy: 0.9822\n",
      "Epoch: 8/10 Train Loss: 0.0069 Accuracy: 0.9983 Time: 13.95940  | Val Loss: 0.0548 Accuracy: 0.9847\n",
      "Epoch: 9/10 Train Loss: 0.0063 Accuracy: 0.9981 Time: 13.81623  | Val Loss: 0.0556 Accuracy: 0.9860\n",
      "Epoch: 10/10 Train Loss: 0.0037 Accuracy: 0.9992 Time: 13.85679  | Val Loss: 0.0519 Accuracy: 0.9855\n"
     ]
    }
   ],
   "source": [
    "from hdd.train.early_stopping import EarlyStoppingInMem\n",
    "from hdd.train.classification_utils import naive_train_classification_model\n",
    "\n",
    "# Start again from the ImageNet weights and freeze the convolutional backbone.\n",
    "model = vgg19_bn(weights=VGG19_BN_Weights.IMAGENET1K_V1)\n",
    "model.features.requires_grad_(False)\n",
    "# NOTE(review): requires_grad_(False) only stops gradient updates. BatchNorm running statistics\n",
    "# in model.features may still change if the training helper puts the model in train mode - confirm.\n",
    "dropout = 0.5\n",
    "num_classes = 10\n",
    "# Replace only the classifier head so that it emits num_classes logits.\n",
    "model.classifier = nn.Sequential(\n",
    "    nn.Linear(512 * 7 * 7, 4096),\n",
    "    nn.ReLU(True),\n",
    "    nn.Dropout(p=dropout),\n",
    "    nn.Linear(4096, 4096),\n",
    "    nn.ReLU(True),\n",
    "    nn.Dropout(p=dropout),\n",
    "    nn.Linear(4096, num_classes),\n",
    ")\n",
    "model = model.to(DEVICE)\n",
    "\n",
    "criteria = nn.CrossEntropyLoss()\n",
    "# Only the new head is trainable, so the optimizer is given just those parameters\n",
    "# instead of the whole model including the frozen backbone.\n",
    "optimizer = torch.optim.SGD(model.classifier.parameters(), lr=0.005, momentum=0.9)\n",
    "max_epochs = 10\n",
    "_ = naive_train_classification_model(\n",
    "    model,\n",
    "    criteria,\n",
    "    max_epochs,\n",
    "    train_dataloader,\n",
    "    val_dataloader,\n",
    "    DEVICE,\n",
    "    optimizer,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5200fad1",
   "metadata": {},
   "source": [
    "#### 微调全部参数,准确率为**97.86%**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1ae7b7df",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1/20 Train Loss: 0.4179 Accuracy: 0.8835 Time: 36.22130  | Val Loss: 0.1110 Accuracy: 0.9651\n",
      "Epoch: 2/20 Train Loss: 0.0698 Accuracy: 0.9789 Time: 35.92898  | Val Loss: 0.0908 Accuracy: 0.9715\n",
      "Epoch: 3/20 Train Loss: 0.0466 Accuracy: 0.9833 Time: 35.94310  | Val Loss: 0.1304 Accuracy: 0.9638\n",
      "Epoch: 4/20 Train Loss: 0.0339 Accuracy: 0.9893 Time: 36.14090  | Val Loss: 0.1331 Accuracy: 0.9679\n",
      "Epoch: 5/20 Train Loss: 0.0328 Accuracy: 0.9905 Time: 36.03434  | Val Loss: 0.1596 Accuracy: 0.9557\n",
      "Epoch: 6/20 Train Loss: 0.0340 Accuracy: 0.9894 Time: 35.97279  | Val Loss: 0.1985 Accuracy: 0.9483\n",
      "Epoch: 7/20 Train Loss: 0.0260 Accuracy: 0.9922 Time: 36.00828  | Val Loss: 0.1182 Accuracy: 0.9699\n",
      "Epoch: 8/20 Train Loss: 0.0290 Accuracy: 0.9923 Time: 36.04504  | Val Loss: 0.1182 Accuracy: 0.9674\n",
      "Epoch: 9/20 Train Loss: 0.0201 Accuracy: 0.9942 Time: 36.17785  | Val Loss: 0.1121 Accuracy: 0.9694\n",
      "Epoch: 10/20 Train Loss: 0.0143 Accuracy: 0.9959 Time: 36.29167  | Val Loss: 0.1257 Accuracy: 0.9707\n",
      "Epoch: 11/20 Train Loss: 0.0077 Accuracy: 0.9977 Time: 36.24832  | Val Loss: 0.0931 Accuracy: 0.9763\n",
      "Epoch: 12/20 Train Loss: 0.0034 Accuracy: 0.9994 Time: 36.53615  | Val Loss: 0.0888 Accuracy: 0.9761\n",
      "Epoch: 13/20 Train Loss: 0.0028 Accuracy: 0.9994 Time: 35.54069  | Val Loss: 0.0873 Accuracy: 0.9766\n",
      "Epoch: 14/20 Train Loss: 0.0020 Accuracy: 0.9996 Time: 36.38922  | Val Loss: 0.0884 Accuracy: 0.9771\n",
      "Epoch: 15/20 Train Loss: 0.0027 Accuracy: 0.9990 Time: 36.43815  | Val Loss: 0.0884 Accuracy: 0.9781\n",
      "Epoch: 16/20 Train Loss: 0.0015 Accuracy: 0.9999 Time: 36.45486  | Val Loss: 0.0875 Accuracy: 0.9778\n",
      "Epoch: 17/20 Train Loss: 0.0019 Accuracy: 0.9994 Time: 36.33758  | Val Loss: 0.0856 Accuracy: 0.9791\n",
      "Epoch: 18/20 Train Loss: 0.0015 Accuracy: 0.9997 Time: 36.50615  | Val Loss: 0.0861 Accuracy: 0.9794\n",
      "Epoch: 19/20 Train Loss: 0.0019 Accuracy: 0.9995 Time: 36.64361  | Val Loss: 0.0881 Accuracy: 0.9791\n",
      "Epoch: 20/20 Train Loss: 0.0013 Accuracy: 0.9997 Time: 36.53087  | Val Loss: 0.0871 Accuracy: 0.9786\n"
     ]
    }
   ],
   "source": [
    "# Fine-tune every layer: reload the ImageNet weights and leave all parameters trainable.\n",
    "model = vgg19_bn(weights=VGG19_BN_Weights.IMAGENET1K_V1)\n",
    "dropout = 0.5\n",
    "num_classes = 10\n",
    "# Fresh classifier head with num_classes outputs; the backbone keeps its pretrained weights.\n",
    "model.classifier = nn.Sequential(\n",
    "    nn.Linear(in_features=512 * 7 * 7, out_features=4096),\n",
    "    nn.ReLU(inplace=True),\n",
    "    nn.Dropout(dropout),\n",
    "    nn.Linear(in_features=4096, out_features=4096),\n",
    "    nn.ReLU(inplace=True),\n",
    "    nn.Dropout(dropout),\n",
    "    nn.Linear(in_features=4096, out_features=num_classes),\n",
    ")\n",
    "model = model.to(DEVICE)\n",
    "\n",
    "max_epochs = 20\n",
    "criteria = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9)\n",
    "# Multiply the learning rate by 0.1 every 10 epochs.\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)\n",
    "_ = naive_train_classification_model(\n",
    "    model, criteria, max_epochs,\n",
    "    train_dataloader, val_dataloader, DEVICE,\n",
    "    optimizer, scheduler,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a90f779",
   "metadata": {},
   "source": [
    "#### 从头训练,准确率为**85%**\n",
    "\n",
    "注意,我们并没有仔细调整相关的超参数,所以结果看起来有些糟糕,根据[VGGNet.ipynb](./VGGNet.ipynb),应该可以达到**91%**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "1fcf5303",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1/150 Train Loss: 2.5135 Accuracy: 0.1973 Time: 36.43307  | Val Loss: 2.0769 Accuracy: 0.2456\n",
      "Epoch: 2/150 Train Loss: 1.9793 Accuracy: 0.3075 Time: 35.91563  | Val Loss: 1.8639 Accuracy: 0.3811\n",
      "Epoch: 3/150 Train Loss: 1.7788 Accuracy: 0.3878 Time: 36.59997  | Val Loss: 1.6883 Accuracy: 0.4382\n",
      "Epoch: 4/150 Train Loss: 1.5756 Accuracy: 0.4730 Time: 36.77549  | Val Loss: 1.4284 Accuracy: 0.5231\n",
      "Epoch: 5/150 Train Loss: 1.3437 Accuracy: 0.5653 Time: 36.72153  | Val Loss: 1.2293 Accuracy: 0.6005\n",
      "Epoch: 6/150 Train Loss: 1.1960 Accuracy: 0.6163 Time: 37.17435  | Val Loss: 1.3629 Accuracy: 0.5837\n",
      "Epoch: 7/150 Train Loss: 1.0645 Accuracy: 0.6608 Time: 36.62234  | Val Loss: 1.0919 Accuracy: 0.6428\n",
      "Epoch: 8/150 Train Loss: 0.9639 Accuracy: 0.6947 Time: 36.77001  | Val Loss: 1.0148 Accuracy: 0.6884\n",
      "Epoch: 9/150 Train Loss: 0.9033 Accuracy: 0.7162 Time: 36.14575  | Val Loss: 0.9161 Accuracy: 0.7213\n",
      "Epoch: 10/150 Train Loss: 0.8419 Accuracy: 0.7340 Time: 35.76139  | Val Loss: 0.8367 Accuracy: 0.7350\n",
      "Epoch: 11/150 Train Loss: 0.7693 Accuracy: 0.7567 Time: 36.23947  | Val Loss: 0.9065 Accuracy: 0.7131\n",
      "Epoch: 12/150 Train Loss: 0.7156 Accuracy: 0.7785 Time: 36.56012  | Val Loss: 1.0121 Accuracy: 0.7045\n",
      "Epoch: 13/150 Train Loss: 0.6518 Accuracy: 0.7978 Time: 36.07125  | Val Loss: 0.9570 Accuracy: 0.7185\n",
      "Epoch: 14/150 Train Loss: 0.6041 Accuracy: 0.8104 Time: 35.70141  | Val Loss: 0.7927 Accuracy: 0.7651\n",
      "Epoch: 15/150 Train Loss: 0.5724 Accuracy: 0.8203 Time: 35.22193  | Val Loss: 0.7134 Accuracy: 0.7829\n",
      "Epoch: 16/150 Train Loss: 0.5418 Accuracy: 0.8304 Time: 34.32799  | Val Loss: 0.6957 Accuracy: 0.7949\n",
      "Epoch: 17/150 Train Loss: 0.4939 Accuracy: 0.8401 Time: 34.35956  | Val Loss: 0.6395 Accuracy: 0.7987\n",
      "Epoch: 18/150 Train Loss: 0.4625 Accuracy: 0.8519 Time: 34.30803  | Val Loss: 0.7631 Accuracy: 0.7750\n",
      "Epoch: 19/150 Train Loss: 0.4466 Accuracy: 0.8601 Time: 34.59644  | Val Loss: 0.5783 Accuracy: 0.8176\n",
      "Epoch: 20/150 Train Loss: 0.3980 Accuracy: 0.8772 Time: 34.61515  | Val Loss: 0.6874 Accuracy: 0.7868\n",
      "Epoch: 21/150 Train Loss: 0.2847 Accuracy: 0.9091 Time: 34.36657  | Val Loss: 0.5767 Accuracy: 0.8329\n",
      "Epoch: 22/150 Train Loss: 0.2472 Accuracy: 0.9206 Time: 34.47630  | Val Loss: 0.5367 Accuracy: 0.8423\n",
      "Epoch: 23/150 Train Loss: 0.2223 Accuracy: 0.9279 Time: 35.20886  | Val Loss: 0.5293 Accuracy: 0.8522\n",
      "Epoch: 24/150 Train Loss: 0.2080 Accuracy: 0.9315 Time: 35.29115  | Val Loss: 0.6256 Accuracy: 0.8270\n",
      "Epoch: 25/150 Train Loss: 0.1929 Accuracy: 0.9380 Time: 34.68341  | Val Loss: 0.6091 Accuracy: 0.8301\n",
      "Epoch: 26/150 Train Loss: 0.1966 Accuracy: 0.9351 Time: 34.40939  | Val Loss: 0.6540 Accuracy: 0.8268\n",
      "Epoch: 27/150 Train Loss: 0.1740 Accuracy: 0.9453 Time: 34.90888  | Val Loss: 0.6176 Accuracy: 0.8410\n",
      "Epoch: 28/150 Train Loss: 0.1591 Accuracy: 0.9487 Time: 35.21354  | Val Loss: 0.6410 Accuracy: 0.8392\n",
      "Epoch: 29/150 Train Loss: 0.1470 Accuracy: 0.9528 Time: 34.75512  | Val Loss: 0.6155 Accuracy: 0.8443\n",
      "Epoch: 30/150 Train Loss: 0.1485 Accuracy: 0.9516 Time: 35.40896  | Val Loss: 0.5983 Accuracy: 0.8392\n",
      "Epoch: 31/150 Train Loss: 0.1382 Accuracy: 0.9527 Time: 34.61559  | Val Loss: 0.5918 Accuracy: 0.8469\n",
      "Epoch: 32/150 Train Loss: 0.1484 Accuracy: 0.9521 Time: 34.73324  | Val Loss: 0.7000 Accuracy: 0.8290\n",
      "Epoch: 33/150 Train Loss: 0.1242 Accuracy: 0.9586 Time: 35.57074  | Val Loss: 0.6283 Accuracy: 0.8362\n",
      "Epoch: 34/150 Train Loss: 0.1233 Accuracy: 0.9616 Time: 35.24597  | Val Loss: 0.7100 Accuracy: 0.8290\n",
      "Epoch: 35/150 Train Loss: 0.1156 Accuracy: 0.9615 Time: 35.16128  | Val Loss: 0.6562 Accuracy: 0.8428\n",
      "Epoch: 36/150 Train Loss: 0.1145 Accuracy: 0.9628 Time: 36.54743  | Val Loss: 0.6711 Accuracy: 0.8390\n",
      "Epoch: 37/150 Train Loss: 0.0958 Accuracy: 0.9690 Time: 35.91109  | Val Loss: 0.6965 Accuracy: 0.8415\n",
      "Epoch: 38/150 Train Loss: 0.1171 Accuracy: 0.9634 Time: 36.47686  | Val Loss: 0.7388 Accuracy: 0.8326\n",
      "Epoch: 39/150 Train Loss: 0.1056 Accuracy: 0.9651 Time: 36.11868  | Val Loss: 0.6475 Accuracy: 0.8494\n",
      "Epoch: 40/150 Train Loss: 0.0807 Accuracy: 0.9738 Time: 35.30998  | Val Loss: 0.6929 Accuracy: 0.8354\n",
      "Epoch: 41/150 Train Loss: 0.0469 Accuracy: 0.9836 Time: 35.32043  | Val Loss: 0.6620 Accuracy: 0.8609\n",
      "Epoch: 42/150 Train Loss: 0.0325 Accuracy: 0.9909 Time: 35.31182  | Val Loss: 0.7192 Accuracy: 0.8573\n",
      "Epoch: 43/150 Train Loss: 0.0298 Accuracy: 0.9905 Time: 34.79335  | Val Loss: 0.7662 Accuracy: 0.8555\n",
      "Epoch: 44/150 Train Loss: 0.0306 Accuracy: 0.9897 Time: 34.42533  | Val Loss: 0.7005 Accuracy: 0.8581\n",
      "Epoch: 45/150 Train Loss: 0.0299 Accuracy: 0.9901 Time: 35.11478  | Val Loss: 0.7929 Accuracy: 0.8591\n",
      "Epoch: 46/150 Train Loss: 0.0235 Accuracy: 0.9923 Time: 34.95982  | Val Loss: 0.7875 Accuracy: 0.8568\n",
      "Epoch: 47/150 Train Loss: 0.0217 Accuracy: 0.9932 Time: 34.46689  | Val Loss: 0.7036 Accuracy: 0.8637\n",
      "Epoch: 48/150 Train Loss: 0.0227 Accuracy: 0.9926 Time: 34.84617  | Val Loss: 0.7362 Accuracy: 0.8583\n",
      "Early stop at epoch 48!\n"
     ]
    }
   ],
   "source": [
    "# Train VGG-19-BN from scratch: no pretrained weights are requested, 10 output classes.\n",
    "net = vgg19_bn(num_classes=10, dropout=0.5).to(DEVICE)\n",
    "criteria = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr=0.005, momentum=0.9)\n",
    "# Halve the learning rate every 20 epochs.\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)\n",
    "# Presumably stops the run once validation has not improved for 25 epochs - confirm in hdd.train.early_stopping.\n",
    "early_stopper = EarlyStoppingInMem(patience=25, verbose=False)\n",
    "max_epochs = 150\n",
    "_ = naive_train_classification_model(\n",
    "    net, criteria, max_epochs,\n",
    "    train_dataloader, val_dataloader, DEVICE,\n",
    "    optimizer, scheduler, early_stopper,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4042f8e2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch-cu124",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
